This notebook pulls in data on Nipah RBP entry into CHO-EFNB2 and CHO-EFNB3 cells, filters data, calculates stats, and makes figures¶

In [1]:
# this cell is tagged as parameters for `papermill` parameterization
e2_distances_file = None
func_scores_E2_file = None
func_scores_E3_file = None

filtered_E2_data = None
filtered_E3_data = None
contact_type_plot = None
E2_entry_heatmap = None
E3_entry_heatmap = None
combined_entry_contact_heatmaps = None
entry_heatmap_by_wt_aa_property = None
E2_E3_entry_corr_plot = None
E2_E3_entry_all_muts_plot = None
combined_E2_E3_correlation_plots = None
entry_region_boxplot_plot = None

nipah_config = None
altair_config = None

entropy_file = None
surface = None
In [2]:
# Parameters
func_scores_E2_file = "results/func_effects/averages/CHO_EFNB2_low_func_effects.csv"
func_scores_E3_file = "results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"
e2_distances_file = "results/distances/2vsm_distances.csv"
contact_type_plot = "results/images/contact_type_plot.html"
filtered_E2_data = "results/filtered_data/E2_entry_filtered.csv"
filtered_E3_data = "results/filtered_data/E3_entry_filtered.csv"
E2_entry_heatmap = "results/images/E2_entry_heatmap.html"
E3_entry_heatmap = "results/images/E3_entry_heatmap.html"
combined_entry_contact_heatmaps = "results/images/combined_entry_contact_heatmaps.html"
entry_heatmap_by_wt_aa_property = "results/images/entry_heatmap_by_wt_aa_property.html"
E2_E3_entry_corr_plot = "results/images/E2_E3_entry_corr_plot.html"
E2_E3_entry_all_muts_plot = "results/images/E2_E3_entry_all_muts_plot.html"
combined_E2_E3_correlation_plots = (
    "results/images/combined_E2_E3_correlation_plots.html"
)
nipah_config = "nipah_config.yaml"
altair_config = "data/custom_analyses_data/theme.py"
entropy_file = "results/entropy/entropy.csv"
entry_region_boxplot_plot = "results/images/entry_region_boxplot_plot.html"
surface = "data/custom_analyses_data/surface_exposure.csv"
In [3]:
import math
import os
import re
import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import Bio.SeqIO

import yaml

from Bio import AlignIO
from Bio import PDB
from Bio.Align import PairwiseAligner
In [4]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

if os.getcwd() == '/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/':
    pass
    print("Already in correct directory")
else:
    os.chdir("/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/")
    print("Setup in correct directory")
Setup in correct directory
In [5]:
if surface is None:
    e2_distances_file = "results/distances/2vsm_distances.csv"
    func_scores_E2_file = "results/func_effects/averages/CHO_EFNB2_low_func_effects.csv"
    func_scores_E3_file = "results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"
    
    filtered_E2_data = "results/filtered_data/E2_entry_filtered.csv"
    filtered_E3_data = "results/filtered_data/E3_entry_filtered.csv"
    
    nipah_config = "nipah_config.yaml"
    altair_config = "data/custom_analyses_data/theme.py"
    
    entropy_file = "results/entropy/entropy.csv"
    surface = "data/custom_analyses_data/surface_exposure.csv"

Read in custom altair theme and Import YAML file with parameters¶

In [6]:
if altair_config:
    with open(altair_config, 'r') as file:
        exec(file.read())

with open(nipah_config) as f:
    config = yaml.safe_load(f)

Filter and merge EFNB2 and EFNB3 entry¶

In [7]:
#Import median entry scores from different selections
func_scores_E2 = pd.read_csv(func_scores_E2_file).dropna().round(2) # this removes wildtype observations from df
func_scores_E3 = pd.read_csv(func_scores_E3_file).dropna().round(2) # this removes wildtype observations from df
In [8]:
def num_selections_and_filter(df, name):
    def calculate_filtering_mutants(df, name):
        print(f"\nThe dataset is: {name}")
        
        # How many selections were performed
        max_sels = df["n_selections"].max()
        num_sels_cutoff = (max_sels / 2) + 1

        # Find total number of possible mutants
        total_mut = (602 - 71) * 19
        
        # Filter data
        filter_test = df[
            (df["site"] != 603) & (df["mutant"] != "-") & (df["mutant"] != "*")
        ]

        num_variants_pre_filter = filter_test.shape[0]
        print(
            f"After filtering stop and gaps, there are {num_variants_pre_filter} mutants which is {(num_variants_pre_filter/total_mut) * 100:.1f}%"
        )

        filter_test_times_seen = filter_test[
            filter_test["times_seen"] >= config["func_times_seen_cutoff"]
        ]
        num_variants_times_seen = filter_test_times_seen.shape[0]
        print(
            f"After filtering for {config['func_times_seen_cutoff']} times seen, there are {num_variants_times_seen}, which is {(num_variants_times_seen/total_mut) * 100:.1f}%"
        )

        filter_test_effect_std = filter_test_times_seen[
            filter_test_times_seen["effect_std"] <= config["func_std_cutoff"]
        ]
        num_variants_std = filter_test_effect_std.shape[0]
        print(
            f"After filtering for {config['func_std_cutoff']} std cutoff, there are {num_variants_std}, which is {(num_variants_std/total_mut) * 100 :.1f}%"
        )

        filter_test_n_selections = filter_test_effect_std[
            filter_test_effect_std["n_selections"] >= num_sels_cutoff
        ]
        num_variants_n_selections = filter_test_n_selections.shape[0]
        print(
            f"After filtering for mutants in in all selections, there are {num_variants_n_selections}, which is {(num_variants_n_selections/total_mut) * 100 :.1f}%"
        )

    def apply_filters(df, name):
        # Now do the main filtering
        max_sels = df["n_selections"].max()
        num_sels_cutoff = (max_sels / 2) + 1
        print(
            f"The number of selections a mutant must be observed in is: {num_sels_cutoff}"
        )
        # The main filtering. Filters site 603 (is a stop codon/end of gene and we don't want those mutants). Also filter out stop mutants and apply filtering from config file
        filtered_df = df[
            (df["site"] != 603)
            & (df["mutant"] != "-")
            & (df["mutant"] != "*")
            & (df["times_seen"] >= config["func_times_seen_cutoff"])
            & (df["effect_std"] <= config["func_std_cutoff"])
            & (df["n_selections"] >= num_sels_cutoff)
        ]
        return filtered_df

    # Filtering stats
    calculate_filtering_mutants(df, name)  # call definition above

    # Filter
    filtered_df = apply_filters(df, name)

    # Now write filtered_data to results/
    if name == "E2":
        filtered_df.to_csv(filtered_E2_data, index=False)
    if name == "E3":
        filtered_df.to_csv(filtered_E3_data, index=False)

    # return filtered dataframe
    return filtered_df

# Call the filtering functions
func_scores_E2 = num_selections_and_filter(func_scores_E2, "E2")
func_scores_E3 = num_selections_and_filter(func_scores_E3, "E3")

# make a merged dataframe of ephrin-b2 and ephrin-b3 entry data
def merge_data(df1,df2):
    merged_df = pd.merge(
        df1,
        df2,
        on=["site", "mutant", "wildtype"],
        how="outer",
        suffixes=["_E2", "_E3"],
    )
    df1["cell_type"] = "CHO-EFNB2"
    df2["cell_type"] = "CHO-EFNB3"
    concat_df = pd.concat([df1,df2])

    return merged_df,concat_df

merged_df,concat_df = merge_data(func_scores_E2,func_scores_E3)

# Show some stats of filtered merged data
stats = merged_df.describe().round(1)
display(stats)
The dataset is: E2
After filtering stop and gaps, there are 10040 mutants which is 99.5%
After filtering for 2 times seen, there are 9789, which is 97.0%
After filtering for 1 std cutoff, there are 9733, which is 96.5%
After filtering for mutants in in all selections, there are 9729, which is 96.4%
The number of selections a mutant must be observed in is: 5.0

The dataset is: E3
After filtering stop and gaps, there are 10072 mutants which is 99.8%
After filtering for 2 times seen, there are 9852, which is 97.7%
After filtering for 1 std cutoff, there are 9718, which is 96.3%
After filtering for mutants in in all selections, there are 9589, which is 95.0%
The number of selections a mutant must be observed in is: 4.5
site effect_E2 effect_std_E2 times_seen_E2 n_selections_E2 effect_E3 effect_std_E3 times_seen_E3 n_selections_E3
count 9886.0 9729.0 9729.0 9729.0 9729.0 9589.0 9589.0 9589.0 9589.0
mean 337.2 -0.9 0.3 7.2 8.0 -1.0 0.3 6.1 7.0
std 153.2 1.4 0.2 4.1 0.1 1.5 0.2 3.3 0.1
min 71.0 -3.5 0.0 2.0 5.0 -3.6 0.0 2.0 5.0
25% 205.0 -2.2 0.2 4.6 8.0 -2.5 0.2 4.1 7.0
50% 338.0 -0.3 0.3 6.4 8.0 -0.2 0.3 5.4 7.0
75% 470.0 0.2 0.5 8.5 8.0 0.2 0.5 7.1 7.0
max 602.0 0.6 1.0 64.4 8.0 0.7 1.0 49.1 7.0

Stats¶

In [9]:
def calculate_stats(df, name):
    print(f"For {name}:")
    total_mut = (602 - 71) * 19
    print(f'There are {total_mut} amino acid mutations possible')
    muts_present = df["effect"].shape[0]
    fraction_muts = muts_present / total_mut
    print(
        f"fraction muts present in {name} is {fraction_muts:.2f} {muts_present}/{total_mut}"
    )

    deleterious_muts = df[df["effect"] <= -0.25].shape[0]
    neutral_muts = df[(df["effect"] > -0.25) & (df["effect"] < 0.25)].shape[0]
    positive_muts = df[df["effect"] > 0.25].shape[0]

    frac_bad_muts = deleterious_muts / muts_present
    frac_neutral_muts = neutral_muts / muts_present
    frac_pos_muts = positive_muts / muts_present
    print(
        f"The number of deleterious mutants for {name} is {frac_bad_muts:.2f} {deleterious_muts}/{muts_present}"
    )
    print(
        f"The number of neutral mutants for {name} is {frac_neutral_muts:.2f} {neutral_muts}/{muts_present}"
    )
    print(
        f"The number of positive mutants for {name} is {frac_pos_muts:.2f} {positive_muts}/{muts_present}\n"
    )

calculate_stats(func_scores_E2, "E2")
calculate_stats(func_scores_E3, "E3")
For E2:
There are 10089 amino acid mutations possible
fraction muts present in E2 is 0.96 9729/10089
The number of deleterious mutants for E2 is 0.51 4929/9729
The number of neutral mutants for E2 is 0.26 2529/9729
The number of positive mutants for E2 is 0.23 2213/9729

For E3:
There are 10089 amino acid mutations possible
fraction muts present in E3 is 0.95 9589/10089
The number of deleterious mutants for E3 is 0.49 4741/9589
The number of neutral mutants for E3 is 0.27 2621/9589
The number of positive mutants for E3 is 0.22 2147/9589

How many sites and which sites only have negative entry scores for mutations?¶

In [10]:
def overall_stats_all_neg(df,effect):
    filtered_df = df.groupby('site').filter(lambda group: (group[effect] < 0).all())
    unique = filtered_df['site'].unique()
    print(list(unique))
    total_sites = df['site'].unique().shape[0]
    subset = filtered_df['site'].unique().shape[0]
       
    fraction = subset/total_sites
    percent = fraction * 100
    print(f'The total number of sites are: {total_sites}')
    print(f' The number of sites where all mutants are negative for {effect}: {subset}')
    print(f' The percent of sites where all mutants are negative for {effect}: {percent:.0f}')
    return unique

intolerant_sites_E2 = list(overall_stats_all_neg(func_scores_E2,'effect'))
intolerant_sites_E3 = list(overall_stats_all_neg(func_scores_E3,'effect'))
[106, 107, 108, 111, 112, 113, 120, 121, 125, 126, 127, 130, 138, 146, 151, 157, 158, 159, 162, 163, 165, 167, 172, 189, 203, 205, 206, 207, 208, 216, 229, 240, 246, 251, 253, 254, 257, 258, 259, 260, 262, 263, 264, 266, 267, 303, 323, 331, 347, 355, 382, 387, 395, 412, 460, 467, 487, 489, 493, 499, 500, 503, 506, 537, 563, 565, 574, 594, 598]
The total number of sites are: 532
 The number of sites where all mutants are negative for effect: 69
 The percent of sites where all mutants are negative for effect: 13
[95, 100, 106, 107, 108, 111, 112, 113, 121, 125, 126, 138, 146, 158, 162, 163, 201, 203, 206, 207, 216, 229, 240, 243, 248, 251, 253, 257, 258, 259, 260, 263, 266, 303, 347, 352, 355, 368, 382, 387, 389, 395, 412, 439, 458, 460, 467, 486, 487, 489, 493, 494, 497, 499, 500, 501, 503, 504, 505, 506, 510, 526, 531, 532, 533, 537, 556, 557, 563, 565, 573, 574, 579, 581, 584, 585, 588, 594]
The total number of sites are: 532
 The number of sites where all mutants are negative for effect: 78
 The percent of sites where all mutants are negative for effect: 15
In [11]:
def calculate_top_func_scores(df,effect):
    percentile_95_effect_E2 = df[effect].quantile(0.999)
    cutoff_E2_df_sites = df[df[effect] > percentile_95_effect_E2]
    E2_site_cutoff = cutoff_E2_df_sites['site'].unique()
    print(f'The sites with the highest functional scores for {effect} are: {list(E2_site_cutoff)}')

calculate_top_func_scores(func_scores_E2,'effect')
calculate_top_func_scores(func_scores_E3,'effect')
The sites with the highest functional scores for effect are: [280, 305, 407, 463, 468, 501, 552, 584]
The sites with the highest functional scores for effect are: [183, 315, 337, 358, 378, 393, 418, 419, 554, 597]

Make bubble plots of receptor contact site type (median values per site)¶

In [12]:
def make_bubbleplot_entry_region(df):  # Create a bubble plot using Altair for contact site mutants
    barrel_ranges = {
        "Hydrophobic": config["hydrophobic"],
        "Salt Bridges": config["salt_bridges"],
        "Hydrogen Bonds": config["h_bond_total"],
        "Contact": config["contact_sites"],
        "Overall": list(range(71, 602)),
    }
    custom_order = [
        "Hydrophobic",
        "Salt Bridges",
        "Hydrogen Bonds",
        "Contact",
        "Overall",
    ]
    empty_chart = []
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    for selection in ["CHO-EFNB2", "CHO-EFNB3"]:
        agg_means = []
        tmp_df = df[df["cell_type"] == selection]
        mean_df = tmp_df.groupby("site")[["effect"]].median().reset_index()

        # For each barrel, filter the site_means dataframe to the sites belonging to that barrel and then store the means
        for barrel, sites in barrel_ranges.items():
            subset = mean_df[mean_df["site"].isin(sites)]
            for _, row in subset.iterrows():
                agg_means.append(
                    {"barrel": barrel, "effect": row["effect"], "site": row["site"]}
                )
        agg_means_df = pd.DataFrame(agg_means)
        chart = (
            alt.Chart(agg_means_df, title=f"{selection}")
            .mark_point(color="black", filled=True)
            .encode(
                x=alt.X(
                    "barrel:O",
                    sort=custom_order,
                    title='Contact Type',
                    axis=alt.Axis(labelAngle=-90),
                ),
                y=alt.Y(
                    "effect",
                    title="Median Cell Entry",
                    axis=alt.Axis(grid=True, tickCount=4),
                ),
                xOffset="random:Q",
                tooltip=["barrel", "effect", "site"],
                size=alt.condition(
                    variant_selector, alt.value(100), alt.value(25)
                ),
                color=alt.condition(
                    variant_selector, alt.value("orange"), alt.value("black")
                ),
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.3)),
            )
            .transform_calculate(random="sqrt(-1*log(random()))*cos(2*PI*random())")
            .properties(
                height=200,
                width=200,
            )
        )
        empty_chart.append(chart)
    combined_effect_chart = (
        alt.hconcat(*empty_chart)
        .resolve_scale(y="shared", x="shared", color="independent")
        .add_params(variant_selector)
    )
    return combined_effect_chart


tmp_img = make_bubbleplot_entry_region(concat_df)
tmp_img.display()
tmp_img.save(contact_type_plot)

Make bubble plots depending on amino acid property¶

In [13]:
def make_bubbleplot_wildtype_prop(df):
    barrel_ranges = {
        "hydrophobic": list(["A", "V", "L", "I", "M"]),
        "aromatic": list(["Y", "W", "F"]),
        "positive": list(["K", "R", "H"]),
        "negative": list(["E", "D"]),
        "hydrophilic": list(["S", "T", "N", "Q"]),
        "special": list(["C", "P", "G"]),
    }
    empty_charts = []
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    for selection in ["CHO-EFNB2", "CHO-EFNB3"]:
        if selection == "CHO-EFNB2":
            effect_name = "EFNB2"
        else:
            effect_name = "EFNB3"

        tmp_df = df[df["cell_type"] == selection]

        unique_wildtype_df = tmp_df[["site", "wildtype"]].drop_duplicates()
        mean_df = tmp_df.groupby("site")[["effect"]].median().reset_index()
        mean_df = pd.merge(mean_df, unique_wildtype_df, on="site", how="left")

        agg_means = []

        # For each barrel, filter the site_means dataframe to the sites belonging to that barrel and then store the means
        for barrel, sites in barrel_ranges.items():
            subset = mean_df[mean_df["wildtype"].isin(sites)]
            for _, row in subset.iterrows():
                agg_means.append(
                    {"wildtype_class": barrel, "effect": row["effect"], "site": row["site"], "wildtype": row["wildtype"]}
                )
        agg_means_df = pd.DataFrame(agg_means)

        chart = (
            alt.Chart(agg_means_df, title=f"{selection}")
            .mark_point(filled=True)
            .encode(
                x=alt.X(
                    "wildtype_class:O",
                    title="Wildtype amino acid property",
                    axis=alt.Axis(labelAngle=-90),
                ),  # sort=custom_order
                y=alt.Y(
                    "effect",
                    title=f"Median Cell Entry",
                    axis=alt.Axis(grid=True, tickCount=4),
                ),
                xOffset="random:Q",
                # color = alt.Color('barrel').legend(None),
                tooltip=["wildtype_class", "effect", "site","wildtype"],
                size=alt.condition(variant_selector, alt.value(100), alt.value(25)),
                color=alt.condition(
                    variant_selector, alt.value("orange"), alt.value("black")
                ),
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.3)),
            )
            .transform_calculate(random="sqrt(-1*log(random()))*cos(2*PI*random())")
            .properties(height=200, width=200)
        )
        empty_charts.append(chart)
    combined_effect_chart = (
        alt.hconcat(*empty_charts)
        .resolve_scale(y="shared", x="shared", color="independent")
        .add_params(variant_selector)
    )
    return combined_effect_chart


wildtype_aa_bubble_img = make_bubbleplot_wildtype_prop(concat_df)
wildtype_aa_bubble_img.display()

Plot correlations between E2 and E3 entry¶

In [14]:
# Import distance data
e2_distances = pd.read_csv(e2_distances_file)
distance_df = pd.merge(
    merged_df, e2_distances[["site", "distance"]], on="site", how="left"
)


def determine_distance(df):
    # Define the conditions
    conditions = [
        df["distance"] < 4,
        (df["distance"] >= 4) & (df["distance"] <= 8),
        df["distance"] > 8,
    ]

    # Define the associated values for the conditions
    choices = ["contact", "close", "distant"]

    # Apply the conditions and choices to the 'E2_contact' column
    df["contact"] = np.select(conditions, choices, default="distant")
    return df


distance_df = determine_distance(distance_df)


def median_correlation_plot(df, metric):
    aggregation = getattr(df.groupby("site")[["effect_E2", "effect_E3"]], metric)
    means = aggregation().reset_index()
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
        means["effect_E2"], means["effect_E3"]
    )
    display(r_value.round(2))

    means = means.rename(
        columns={"effect_E2": f"{metric}_E2", "effect_E3": f"{metric}_E3"}
    )

    contact_sites = df[["site", "contact", "wildtype"]].drop_duplicates()
    df_mean = pd.merge(means, contact_sites, on="site", how="left")

    chart = (
        alt.Chart(df_mean)
        .mark_point(size=40, opacity=0.5, filled=True)
        .encode(
            x=alt.X(f"{metric}_E2", title="Summed Cell Entry for EFNB2"),
            y=alt.Y(f"{metric}_E3", title="Summed Cell Entry for EFNB3"),
            tooltip=["site", "wildtype"],
            color=alt.Color(
                "contact",
                scale=alt.Scale(
                    domain=["contact", "close", "distant"],
                    range=["#1f4e79", "#ff7f0e", "gray"],
                ),
                legend=alt.Legend(title="RBP Distance to Receptor"),
            ),
        )
        .properties(width=200, height=200)
    )
    min_ = int(df_mean[f"{metric}_E2"].min())
    max_ = int(df_mean[f"{metric}_E3"].max())
    text = (
        alt.Chart({"values": [{"x": min_, "y": max_, "text": f"r = {r_value:.2f}"}]})
        .mark_text(
            align="left",
            baseline="top",
            dx=0,  # Adjust this for position
            dy=0,  # Adjust this for position
        )
        .encode(x=alt.X("x:Q"), y=alt.Y("y:Q"), text="text:N")
    )
    plot = chart + text
    return plot


E2_E3_plot = median_correlation_plot(distance_df, "sum")
E2_E3_plot.display()
E2_E3_plot.save(E2_E3_entry_corr_plot)


def correlation_plot(df):
    slider = alt.binding_range(min=2, max=25, step=1, name="times_seen")
    selector = alt.param(name="SelectorName", value=2, bind=slider)
    chart = (
        alt.Chart(df)
        .mark_point(size=30, filled=True)
        .encode(
            x=alt.X("effect_E2", title="Cell Entry for EFNB2"),
            y=alt.Y("effect_E3", title="Cell Entry for EFNB3"),
            tooltip=["site", "wildtype", "mutant"],
            opacity=alt.condition(
                alt.datum.times_seen_E2 < selector, alt.value(0), alt.value(0.2)
            ),
            color=alt.Color(
                "contact",
                scale=alt.Scale(
                    domain=["contact", "close", "distant"],
                    range=["#1f4e79", "#ff7f0e", "gray"],
                ),
                legend=alt.Legend(title="RBP Distance to Receptor"),
            ),
        )
        .properties(width=200, height=200)
        .add_params(selector)
    )
    min_ = int(df["effect_E2"].min())
    max_ = int(df["effect_E3"].max())

    return chart


tmp_img_corr = correlation_plot(distance_df)
tmp_img_corr.display()
tmp_img_corr.save(E2_E3_entry_all_muts_plot)


(E2_E3_plot | tmp_img_corr).save(combined_E2_E3_correlation_plots)
0.81

Make boxplot showing median entry by RBP region¶

In [15]:
def make_boxplot_entry_region(df):
    barrel_ranges = {
        "Stalk": list(range(96, 147)),
        "Neck": list(range(148, 165)),
        "Linker": list(range(166, 177)),
        "Head": list(range(178, 602)),
        'Receptor Contact': config['contact_sites'],
        "Total": list(range(71, 602)),
    }
    custom_order = ["Stalk", "Neck", "Linker", "Head", "Receptor Contact", "Total"]
    empty_charts = []
    for selection in ["CHO-EFNB2", "CHO-EFNB3"]:
        if selection == "CHO-EFNB2":
            effect_name = "EFNB2"
        else:
            effect_name = "EFNB3"

        tmp_df = df[df["cell_type"] == selection]
        agg_means = []

        # For each barrel, filter the site_means dataframe to the sites belonging to that barrel and then store the means
        for barrel, sites in barrel_ranges.items():
            subset = tmp_df[tmp_df["site"].isin(sites)]
            for _, row in subset.iterrows():
                agg_means.append(
                    {"region": barrel, "effect": row["effect"], "site": row["site"]}
                )
            agg_means_df = pd.DataFrame(agg_means)

        chart = (
            alt.Chart(agg_means_df, title=f"{selection}")
            .mark_boxplot(color="darkgray", extent="min-max", opacity=1)
            .encode(
                x=alt.X(
                    "region:O",
                    sort=custom_order,
                    title="RBP Region",
                    axis=alt.Axis(labelAngle=-90),
                ),
                y=alt.Y(
                    "effect",
                    title=f"Cell Entry",
                    axis=alt.Axis(grid=True, tickCount=4),
                ),
                tooltip=["region", "effect", "site"],
            )
            .properties(height=200, width=200)
        )
        empty_charts.append(chart)
    combined_effect_chart = alt.hconcat(*empty_charts).resolve_scale(
        y="shared", x="shared", color="independent"
    )
    return combined_effect_chart


entry_region_boxplot = make_boxplot_entry_region(concat_df)
entry_region_boxplot.display()
entry_region_boxplot.save(entry_region_boxplot_plot)

Generate Full Heatmap For EFNB2 and EFNB3¶

Define and prep the heatmaps. This consists of making separate dataframes for entropy, contact, and empty dataframes of all possible amino acid mutations. Then combining them.

In [16]:
def prepare_entropy():  # need to prepare entropy data for plotting on heatmap
    # read in entropy data, calculated in different notebook
    entropy = pd.read_csv(entropy_file)
    df = entropy[["site", "henipavirus_entropy"]]
    df = df.dropna(subset=["site"])
    df["site"] = df["site"].astype("Int64")
    df = df.rename(columns={"henipavirus_entropy": "entropy"})
    df['entropy'] = df['entropy'].round(2)
    df = df[["site", "entropy"]].drop_duplicates()
    df["mutant"] = "entropy"
    df["wildtype"] = ""
    df["type"] = "entropy"
    df.rename(columns={"entropy": 'value'}, inplace=True)
    return df

def make_contact():
    df = pd.DataFrame(
        {
            "site": config["contact_sites"],
            "contact": [0.0] * len(config["contact_sites"]),
        }
    )
    # Renaming and restructuring the dataframe as per your original function
    df["mutant"] = "contact"
    df["wildtype"] = ""
    df["type"] = "contact"
    df.rename(columns={"contact": 'value'}, inplace=True)
    return df

# This gets called during heatmap generation
def make_empty_df(df, contact_df=None, entropy_df=None, contact_flag=None, entropy_flag=None):
    sites = range(71, 603)
    amino_acids = ["R", "K", "H", "D", "E", "Q", "N", "S", "T", "Y", "W", "F", "A", "I", "L", "M", "V", "G", "P", "C"]
    # Create the combination of each site with each amino acid
    data = [{"site": site, "mutant": aa} for site in sites for aa in amino_acids]

    # Create the DataFrame
    empty_df = pd.DataFrame(data)
    all_sites_df = pd.merge(empty_df, df, on=["site", "mutant"], how="left")

    df_test = all_sites_df.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["effect"],
        var_name="type",
        value_name='value',
    )
    if contact_flag and entropy_flag is None:
        df_test = pd.concat([df_test], ignore_index=True)
    if contact_flag is True:
        df_test = pd.concat([df_test, contact_df], ignore_index=True)
    if entropy_flag is True:
        df_test = pd.concat([df_test, entropy_df], ignore_index=True)
    if entropy_flag and contact_flag is True:
        df_test = pd.concat([df_test, entropy_df, contact_df], ignore_index=True)


    return df_test

Next define how the different heatmaps will be made. These are separate heatmaps that get combined

In [17]:
# Make the base heatmap. This contains information about the x_axis and heatmap_sites which are important for sorting them correctly. 
def make_base_heatmap(df, heatmap_sites, x_axis):
    base = (
        alt.Chart(df)
        .encode(
            x=alt.X("site:O", title="Site", sort=heatmap_sites, axis=x_axis),
            y=alt.Y(
                "mutant",
                title="Amino Acid",
                sort=alt.EncodingSortField(field='mutant_rank', order='ascending'),
                axis=alt.Axis(grid=False),
            ),
        )
        .properties(
            width=alt.Step(10),
            height=alt.Step(11),
        )
    )
    return base

# This makes an 'empty' heatmap that shows sites that were not observed as some color (default:gray)
def make_empty_heatmap(base, background_color):
    chart_empty = (
        base.mark_rect(color=background_color)
        .encode(
            tooltip=['site', 'mutant']
        )
        .transform_filter(
            (alt.datum.type == "effect") & (alt.datum.value == None) 
        )
    )
    return chart_empty
# This makes the white squares and X for the wildtype amino acids
def make_wildtype_heatmap(unique_wildtypes_df, strokewidth_size, heatmap_sites):
    wildtype_layer_box = (
        alt.Chart(unique_wildtypes_df)
        .mark_rect(color="white", stroke="black", strokeWidth=strokewidth_size)
        .encode(
            x=alt.X("site:O", sort=heatmap_sites),
            y=alt.Y("wildtype", sort=alt.EncodingSortField(field="mutant_rank", order="ascending")),
            tooltip=["site", "wildtype"],
        )
        .transform_filter(
            (alt.datum.type == "effect") & (alt.datum.wildtype != None) & (alt.datum.value != None)
        )
    )
    wildtype_layer = (
        alt.Chart(unique_wildtypes_df)
        .mark_text(color="black", text="X", size=8, align="center", baseline="middle")
        .encode(
            x=alt.X("site:O", sort=heatmap_sites),
            y=alt.Y("wildtype", sort=alt.EncodingSortField(field="mutant_rank", order="ascending")),
            tooltip=["site", "wildtype"],
        )
        .transform_filter(
            (alt.datum.type == "effect") & (alt.datum.wildtype != None) & (alt.datum.value != None)
        )
    )
    return wildtype_layer_box, wildtype_layer

# This makes the actual effect heatmap, and adds a bar for the legend if its the first time through the loop
def create_effect_chart(base, color_scale_effect, strokewidth_size, legend_title=None, effect_legend_added=None):
    legend = alt.Legend(title=legend_title) if effect_legend_added is True else None
    chart = (
        base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
        .encode(
            color=alt.condition(
                'datum.type == "effect"',
                alt.Color('value:Q', scale=color_scale_effect, legend=legend),
                alt.value("transparent"),
            ),
            tooltip=['site', 'mutant', 'wildtype', 'value']
        )
        .transform_filter(
            (alt.datum.wildtype != '') & (alt.datum.wildtype != None)
        )
    )
    return chart
# This makes a chart for the entropy values 
def create_entropy_chart(base, color_scale_entropy, strokewidth_size, legend_title=None, entropy_legend_added=None):
    legend = alt.Legend(title='Henipavirus Entropy') if entropy_legend_added is True else None
    chart = (
        base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
        .encode(
            color=alt.condition(
                'datum.mutant == "entropy"',
                alt.Color('value:Q', scale=color_scale_entropy, legend=legend),
                alt.value("transparent"),
            ),
            tooltip=['site', 'mutant', 'wildtype', 'value']
        )
    )
    return chart
# This makes a chart for the contact sites
def create_contact_chart(base):
    chart_contact = (
        base.mark_rect(color="black")
        .encode(tooltip=['site'])
        .transform_filter(
            (alt.datum.mutant == "contact")
        )
    )
    return chart_contact

# This compiles all the different charts and returns a single chart
def compile_chart(df, heatmap_sites, unique_wildtypes_df, x_axis, background_color, color_scale_effect, color_scale_entropy, strokewidth_size=None, legend_title=None, effect_legend_added=None, entropy_legend_added=None):
    base = make_base_heatmap(df, heatmap_sites, x_axis)
    chart_empty = make_empty_heatmap(base, background_color)
    chart_contact = create_contact_chart(base)
    chart_effect = create_effect_chart(base, color_scale_effect, strokewidth_size, legend_title, effect_legend_added)
    chart_entropy = create_entropy_chart(base, color_scale_entropy, strokewidth_size, legend_title, entropy_legend_added)
    wildtype_layer_box, wildtype_layer = make_wildtype_heatmap(unique_wildtypes_df, strokewidth_size, heatmap_sites)
    
    chart = alt.layer(
        chart_empty,
        chart_effect,
        chart_entropy,
        chart_contact,
        wildtype_layer_box,
        wildtype_layer,

    ).resolve_scale(y="shared", x="shared", color="independent")
    
    return chart
In [18]:
def plot_entry_heatmap(
    df, 
    legend_title, 
    null_color=None, 
    ranges=None, 
    effect_color=None, 
    entropy_color=None,
    strokewidth_size=None,
    custom_y_axis_order=None,
    entropy_flag=None,
    contact_flag=None,
    specific_sites=None,
    specific_sites_name=None):
    """
    Generates a customizable heatmap for deep mutational scanning (DMS) data visualization.

    Parameters:
    - df (DataFrame): The data frame containing the data to be visualized. It must include the columns 'site', 'mutant', 'value', and 'wildtype'.
    - legend_title (str): The title of the heatmap legend.
    - null_color (str, optional): Color for mutants with no observations. Default is 'gray'.
    - ranges (list of tuples, optional): Defines the ranges for site wrapping on the heatmap. If not provided, a default range is used.
    - effect_color (str, optional): Color scheme for effect values. Default is 'red-blue'.
    - entropy_color (str, optional): Color scheme for entropy values. Default is 'purples'.
    - strokewidth_size (float, optional): The width of the stroke used in the heatmap. Default size is not specified.
    - custom_y_axis_order (list, optional): Specifies a custom order for the y-axis, overriding the default amino acid order.
    - entropy_flag (bool, optional): If True, sequence entropy is included in the heatmap. Default is False.
    - contact_flag (bool, optional): If True, contact sites are included in the heatmap. Default is False.
    - specific_sites (list, optional): Specifies a subset of sites to be plotted. If None, all sites are plotted using wrapping. Default is None.
    - specific_sites_name (str, optional): A title to display at the top of the heatmap for specific sites. Default is None.

    Returns:
    An Altair chart object representing the generated heatmap. This chart can be further customized or directly displayed in Jupyter notebooks or other compatible environments.
    """
    
    if contact_flag:
        contact_df = make_contact()
    else: 
        contact_df = None
    if entropy_flag is True:
        entropy_df = prepare_entropy()
    else:
        entropy_df = None

    # Make the dataframes for plotting.
    empty_df = make_empty_df(df,contact_df=contact_df,entropy_df=entropy_df,contact_flag=contact_flag,entropy_flag=entropy_flag)

    # Define the base order list
    base_order = ["R", "K", "H", "D", "E", "Q", "N", "S", "T", "Y", "W", "F", "A", "I", "L", "M", "V", "G", "P", "C"]
    
    # Initialize custom_order with custom_y_axis_order or base_order based on custom_y_axis_order's value
    custom_order = custom_y_axis_order if custom_y_axis_order is not None else base_order
    # Prepend conditions based on flags
    if entropy_flag and contact_flag:
        # Both flags are true, prepend both "contact" and "entropy"
        custom_order = ["contact", "entropy"] + custom_order
    elif entropy_flag:
        # Only entropy_flag is true, prepend "entropy"
        custom_order = ["entropy"] + custom_order
    elif contact_flag:
        # Only contact_flag is true, prepend "contact"
        custom_order = ["contact"] + custom_order
    
    # Optional parameters
    if null_color is None:
        background_color = "#d1d3d4"
    else:
        background_color = null_color
        
    # Sites for wrapping heatmap correctly
    if ranges is None:
        full_ranges = [
            list(range(start, end))
            for start, end in [(71, 204), (204, 337), (337, 470), (470, 603)]
        ]
    else:
        full_ranges=ranges
    
    # effect_color
    if effect_color is None:
        color_scale_effect = alt.Scale(scheme="redblue", domainMid=0, domain=[-4, 2.5])
    else:
        color_scale_effect = alt.Scale(scheme=effect_color, domainMid=0, domain=[-4, 2.5])
    
    # entropy_color
    if entropy_color is None:
        color_scale_entropy = alt.Scale(scheme="purples", domain=[0, 2], reverse=True)
    else:
        color_scale_entropy = alt.Scale(scheme=entropy_color, domain=[0, 2], reverse=True)
    
    # strokewidth size
    if strokewidth_size is None:
        strokewidth_size = 0.25
    else:
        strokewidth_size = strokewidth_size

    if entropy_flag is None:
        entropy_legend_added = None
    else:
        entropy_legend_added = True
      
    effect_legend_added = True

    def determine_sorting_order(df):
        # Sort the dataframe by 'site' to ensure that duplicates are detected correctly.
        final_df = df.sort_values("site")
        sort_order = {mutant: i for i, mutant in enumerate(custom_order)}
        final_df["mutant_rank"] = final_df["mutant"].map(sort_order)
        # Map the 'mutant' column to these ranks
        # Now sort the dataframe by this rank
        final_df = final_df.sort_values("mutant_rank")
        sites = sorted(final_df["site"].unique(), key=lambda x: float(x))
        return final_df, sites, sort_order
    
    heatmap_df, heatmap_sites,sort_order = determine_sorting_order(empty_df)
    
    # container to hold the charts
    charts = []

    if specific_sites:
        #Filter the heatmap to only show certain sites
        subset_df = heatmap_df[heatmap_df["site"].isin(specific_sites)]
        
        ### Need to do independently for wildtype here for individual sites
        unique_wildtypes_df = subset_df.drop_duplicates(subset=["site","wildtype"])  
        unique_wildtypes_df = unique_wildtypes_df.sort_values("site")
        sort_order = {mutant: i for i, mutant in enumerate(custom_order)}
        unique_wildtypes_df["mutant_rank"] = unique_wildtypes_df["wildtype"].map(sort_order)
        unique_wildtypes_df = unique_wildtypes_df.sort_values("mutant_rank")

        #Setup x-axis labeling
        x_axis = alt.Axis(
                labelAngle=-90,
                title="Site",
                labels=True,
        )  
        # Run the main heatmap compiler function
        chart = compile_chart(subset_df, heatmap_sites, unique_wildtypes_df, x_axis, background_color, color_scale_effect=color_scale_effect, color_scale_entropy=color_scale_entropy,strokewidth_size=strokewidth_size, legend_title=legend_title, effect_legend_added=effect_legend_added, entropy_legend_added=entropy_legend_added)
        #Since this is a single chart, I don't know why I need to do this, but I seem to get errors if I don't append and then do alt.vconcat below. I get why I need to do this for multiple heatmaps in a for loop, but not here. Leaving in.
        charts.append(chart)
        if specific_sites_name:
            specific_sites_name=specific_sites_name
        else:
            specific_sites_name=''
        combined_charts = alt.vconcat(*charts,title=specific_sites_name).resolve_scale(y="shared", x="shared", color="shared")
        return combined_charts
    else:
        for idx, subset in enumerate(full_ranges):
            # Flags for showing the legend only the first time
            subset_df = heatmap_df[
                heatmap_df["site"].isin(subset)
            ]  # for the wrapping of sites
            unique_wildtypes_df = subset_df.drop_duplicates(
                subset=["site", "wildtype"]
            )  # for the wildtype mapping
            
            # Keep track of where in the loop we are for plotting
            is_last_plot = idx == len(full_ranges) - 1
            x_axis = alt.Axis(
                labelAngle=-90,
                labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                title="Site" if is_last_plot else None,
                labels=True,
            )
            chart = compile_chart(subset_df, heatmap_sites, unique_wildtypes_df, x_axis, background_color, color_scale_effect=color_scale_effect, color_scale_entropy=color_scale_entropy,strokewidth_size=strokewidth_size, legend_title=legend_title, effect_legend_added=effect_legend_added, entropy_legend_added=entropy_legend_added)
            charts.append(chart)
            effect_legend_added = None
            entropy_legend_added = None
        combined_chart = alt.vconcat(
            *charts, spacing=3, title=f"{legend_title}"
        ).resolve_scale(y="shared", x="independent", color="shared")
        return combined_chart
In [19]:
E2_entry_heatmap_full = plot_entry_heatmap(
    df = func_scores_E2, 
    legend_title = "CHO-EFNB2 Entry", 
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag = True,
    entropy_flag = True,
)
E2_entry_heatmap_full.display()
E2_entry_heatmap_full.save(E2_entry_heatmap)
In [20]:
E3_entry_heatmap_full = plot_entry_heatmap(
    df = func_scores_E3, 
    legend_title = "CHO-EFNB3 Entry", 
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    contact_flag = True,
    entropy_flag = True
)
E3_entry_heatmap_full.display()
E3_entry_heatmap_full.save(E3_entry_heatmap)
In [21]:
E2_entry_heatmap_contact = plot_entry_heatmap(
    df = func_scores_E2, 
    legend_title = "CHO-EFNB2 Entry", 
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    specific_sites=config['contact_sites'],
    #specific_sites_name='Contact Sites',
    #contact_flag = False,
    #entropy_flag = True,
)
#E2_entry_heatmap_contact.save(E2_entry_contact_heatmap)
E2_entry_heatmap_contact.display()
In [22]:
E3_entry_heatmap_contact = plot_entry_heatmap(
    df = func_scores_E3, 
    legend_title = "CHO-EFNB3 Entry", 
    null_color=config['background_color'],
    effect_color=config['effect_color'],
    entropy_color=config['entropy_color'],
    strokewidth_size=config['strokewidth_size'],
    specific_sites=config['contact_sites'],
    #specific_sites_name='Contact Sites',
    #contact_flag = True,
    #entropy_flag = True,
)
#E3_entry_heatmap_contact.save(E3_entry_contact_heatmap)
E3_entry_heatmap_contact.display()

combined_contact = alt.vconcat(E2_entry_heatmap_contact, E3_entry_heatmap_contact,title='Contact Sites')
combined_contact.display()
combined_contact.save(combined_entry_contact_heatmaps)

Show heatmap of different wildtype amino acid classes¶

In [23]:
hydrophobic_AA = ['A','V','L','I','M']
aromatic_AA = ['Y','W','F']
positive_AA = ['K','R','H']
negative_AA = ['E','D']
hydrophilic_AA = ['S','T','N','Q']

def find_aa_wildtype_sites(df,aa_type):
    aa_list = list(df[df['wildtype'].isin(aa_type)]['site'].unique())
    return aa_list

# Find sites where the WT are different classes of amino acids
hydrophobic_AA_list = find_aa_wildtype_sites(func_scores_E3,hydrophobic_AA)
aromatic_AA_list = find_aa_wildtype_sites(func_scores_E3,aromatic_AA)
positive_AA_list = find_aa_wildtype_sites(func_scores_E3,positive_AA)
negative_AA_list = find_aa_wildtype_sites(func_scores_E3,negative_AA)
hydrophilic_AA_list = find_aa_wildtype_sites(func_scores_E3,hydrophilic_AA)

all_AA = [hydrophobic_AA_list, aromatic_AA_list, positive_AA_list, negative_AA_list, hydrophilic_AA_list]
names = ['Hydrophobic', 'Aromatic', 'Positive', 'Negative', 'Hydrophilic']

charts_empty = []
# Use zip() to iterate through both lists concurrently
for aa_list, name in zip(all_AA, names):
        aa_properties_chart = plot_entry_heatmap(df=func_scores_E2,legend_title='CHO-EFNB2 Entry',specific_sites=aa_list,specific_sites_name=name)
        charts_empty.append(aa_properties_chart)

combined_chart = alt.vconcat(*charts_empty, spacing=3)
combined_chart.display()
combined_chart.save(entry_heatmap_by_wt_aa_property)

Show heatmap of sites that only have deleterious mutations¶

In [24]:
E2_intolerant_chart = plot_entry_heatmap(df=func_scores_E2,legend_title='CHO-EFNB2 Entry',specific_sites=intolerant_sites_E2,specific_sites_name='Highly constrained sites')
E2_intolerant_chart.display()

E3_intolerant_chart = plot_entry_heatmap(df=func_scores_E3,legend_title='CHO-EFNB3 Entry',specific_sites=intolerant_sites_E3,specific_sites_name='Highly constrained sites')
E3_intolerant_chart.display()

Plot heatmap of cysteine and n-linked glycosylation motifs¶

In [25]:
cysteine_neck = [146, 158, 162]
cysteine_1 = [189,601]
cysteine_2 = [216, 240]
cysteine_3 = [282,295]
cysteine_4 = [382, 395]
cysteine_5 = [387, 499]
cysteine_6 = [493, 503]
cysteine_7 = [565, 574]

cysteine = cysteine_neck + cysteine_1 + cysteine_2 + cysteine_3 + cysteine_4 + cysteine_5 + cysteine_6 + cysteine_7

n_linked = config['glycans']
stalk = list(range(96, 147))
neck = list(range(148,166))
linker = list(range(166,177))

df_list = [cysteine,n_linked,stalk,neck,linker]
df_names = ['Cysteines','N-linked Glycans','Stalk','Neck','Linker']

empty_charts = []
for aa_type, name in zip(df_list, df_names):
    E3_glycans_cysteines = plot_entry_heatmap(df=func_scores_E2,legend_title='CHO-EFNB2 Entry',specific_sites=aa_type,specific_sites_name=name)
    empty_charts.append(E3_glycans_cysteines)

combined_chart = alt.vconcat(*empty_charts, spacing=3)
combined_chart.display()

Check for potential neutral/beneficial glycosylation sites¶

In [26]:
def find_potential_glycan_sites(df):
    filtered = df[df["wildtype"].isin(["T", "S"])]
    matching_sites = []
    for index, row in filtered.iterrows():
        # Check for existence of site two numbers prior with 'N' mutant and positive effect
        prior_rows = df[
            (df["site"] == row["site"] - 2) & (df["mutant"] == "N") & (df["effect"] > 0)
        ]
        if not prior_rows.empty:
            matching_sites.append(row["site"])
    unique_list1 = list(set(matching_sites))
    unique_list1 = [x - 2 for x in unique_list1]

    filtered = df[df["wildtype"].isin(["N"])]
    matching_sites = []
    for index, row in filtered.iterrows():
        # Check for existence of site two numbers prior with 'N' mutant and positive effect
        prior_rows = df[
            (df["site"] == row["site"] + 2)
            & (df["mutant"].isin(["T", "S"]))
            & (df["effect"] > 0)
        ]
        if not prior_rows.empty:
            matching_sites.append(row["site"])
    unique_list2 = list(set(matching_sites))
    unique_list = unique_list1 + unique_list2
    unique_list.sort()
    print(f"The total number of potential PNLG sites are: {len(unique_list)}")
    return unique_list


PNLG = find_potential_glycan_sites(func_scores_E3)

surface_df = pd.read_csv(surface)
surface_df.rename(columns={"exposure_A": "Surface Exposure"}, inplace=True)
PNLG_SURFACE = surface_df[surface_df["site"].isin(PNLG)]
PNLG_SURFACE = list(
    PNLG_SURFACE[PNLG_SURFACE["Surface Exposure"] >= 25]["site"].unique()
)

print(f"\nThe surface exposed PNLG sites are: {PNLG_SURFACE}\n")

glycans = config["glycans"]
filtered_PNLG_SURFACE = [num for num in PNLG_SURFACE if num not in glycans]

print(filtered_PNLG_SURFACE)

print(len(filtered_PNLG_SURFACE))
The total number of potential PNLG sites are: 33

The surface exposed PNLG sites are: [159, 180, 191, 192, 275, 288, 306, 311, 326, 378, 383, 403, 417, 423, 473, 478, 518, 543, 554, 570, 600]

[180, 191, 192, 275, 288, 311, 326, 383, 403, 423, 473, 478, 518, 543, 554, 570, 600]
17
In [ ]: